import torch
import torch.nn.functional as F

from ...loss.iou_loss import bbox_overlaps
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


class DynamicSoftLabelAssigner(BaseAssigner):
    """Computes matching between predictions and ground truth with
    dynamic soft label assignment.

    Args:
        topk (int): Select top-k predictions to calculate dynamic k
            best matchs for each gt. Default 13.
        iou_factor (float): The scale factor of iou cost. Default 3.0.
    """

    def __init__(self, topk=13, iou_factor=3.0):
        self.topk = topk
        self.iou_factor = iou_factor

    def assign(
        self,
        pred_scores,
        priors,
        decoded_bboxes,
        gt_bboxes,
        gt_labels,
    ):
        """Assign gt to priors with dynamic soft label assignment.
        Args:
            pred_scores (Tensor): Classification scores of one image,
                a 2D-Tensor with shape [num_priors, num_classes]
            priors (Tensor): All priors of one image, a 2D-Tensor with shape
                [num_priors, 4] in [cx, xy, stride_w, stride_y] format.
            decoded_bboxes (Tensor): Predicted bboxes, a 2D-Tensor with shape
                [num_priors, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_bboxes (Tensor): Ground truth bboxes of one image, a 2D-Tensor
                with shape [num_gts, 4] in [tl_x, tl_y, br_x, br_y] format.
            gt_labels (Tensor): Ground truth labels of one image, a Tensor
                with shape [num_gts].

        Returns:
            :obj:`AssignResult`: The assigned result.
        """
        INF = 100000000
        num_gt = gt_bboxes.size(0)
        num_bboxes = decoded_bboxes.size(0)

        # assign 0 by default，每个 cell(anchor) 的标签分配结果储存数组
        assigned_gt_inds = decoded_bboxes.new_full((num_bboxes, ), 0, dtype=torch.long)

        prior_center = priors[:, :2]  # image 坐标系下，cell 左上角坐标, shape = (num_priors, 2)
        # 如果结果全为正，说明 cell 的左上角在 gt 里面
        lt_ = prior_center[:, None] - gt_bboxes[:, :2]  # 所有cell左上角 与 gt左上角 的差值, shape = (num_priors, num_gt, 2)
        rb_ = gt_bboxes[:, 2:] - prior_center[:, None]  # gt右下角 与 所有cell左上角 的差值, shape = (num_priors, num_gt, 2)

        deltas = torch.cat([lt_, rb_], dim=-1)  # (num_priors, num_gt, 4)，坐标差值 (delta_x1, delta_y1, delta_x2. delta_y2)
        # 判断 每个cell左上角 是否在 gt 里面。 先挑出4个坐标差值的最小值，再看其是否大于0，如果是置为 True.
        is_in_gts = deltas.min(dim=-1).values > 0  # shape = (num_priors, num_gt)
        # 如果 cell的左上角 至少在一个 gt 里面，则为 True ，否则为 False.
        valid_mask = is_in_gts.sum(dim=1) > 0  # shape = (num_priors, )，得到 每个cell的左上角 与 所有gt 的关系数组

        valid_decoded_bbox = decoded_bboxes[valid_mask]  # 筛选，shape = (num_valid, 4)
        valid_pred_scores = pred_scores[valid_mask]  # 筛选, shape = (num_valid, num_classes)
        num_valid = valid_decoded_bbox.size(0)  # 得到符合条件的 bbox 个数

        if num_gt == 0 or num_bboxes == 0 or num_valid == 0:
            # No ground truth or boxes, return empty assignment
            max_overlaps = decoded_bboxes.new_zeros((num_bboxes,))
            if num_gt == 0:
                # No truth, assign everything to background
                assigned_gt_inds[:] = 0  # 如果没有 gt ，则全是背景，标签分配为 0
            if gt_labels is None:
                assigned_labels = None
            else:
                assigned_labels = decoded_bboxes.new_full(  # 背景的类别序号是 -1
                    (num_bboxes,), -1, dtype=torch.long
                )
            return AssignResult(
                num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
            )
        # shape = (num_valid, num_gt). IOU越大，匹配效果越好，我们需要 IOU 大的 bbox结果。
        pairwise_ious = bbox_overlaps(valid_decoded_bbox, gt_bboxes)  # 计算符合条件 bbox 与 gt 的 IOU值
        # 转为 IOU 损失，IOU越大（靠近1），损失越小
        iou_cost = -torch.log(pairwise_ious + 1e-7)

        gt_onehot_label = (
            F.one_hot(gt_labels.to(torch.int64), pred_scores.shape[-1])  # shape = (num_gts, num_classes)
            .float()
            .unsqueeze(0)  # shape = (1, num_gts, num_classes)
            .repeat(num_valid, 1, 1)  # shape = (num_valid, num_gts, num_classes)
        )
        # shape 变为 (num_valid, num_gt, num_classes)
        valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)
        # 沿用了 gfl 的思路，用 IOU 值做分类的 label
        soft_label = gt_onehot_label * pairwise_ious[..., None]
        scale_factor = soft_label - valid_pred_scores
        # 还是 gfl 的思路
        cls_cost = F.binary_cross_entropy(
            valid_pred_scores, soft_label, reduction="none"
        ) * scale_factor.abs().pow(2.0)

        cls_cost = cls_cost.sum(dim=-1)
        # shape = (num_valid, num_gt)。这个cost数组是分类损失与bbox损失的综合损失。
        cost_matrix = cls_cost + iou_cost * self.iou_factor  # IOU更重视，毕竟当前是标签分配阶段，IOU越大，标签与bbox越匹配
        # matched_pred_ious ：shape = (bbox_match_gt_num, )，获得了 bbox 与其匹配的 gt 的 IOU 值
        # matched_gt_inds ：shape = (bbox_match_gt_num, )，里面的元素是 bbox 匹配到的 gt 的列索引
        matched_pred_ious, matched_gt_inds = self.dynamic_k_matching(
            cost_matrix, pairwise_ious, num_gt, valid_mask
        )
        # convert to AssignResult format
        # dynamic_k_matching函数里对valid_mask的修改不会改变其内存地址，所以此处的valid_mask与函数里的是同一个
        assigned_gt_inds[valid_mask] = matched_gt_inds + 1  # 背景是0，前景从1开始算
        assigned_labels = assigned_gt_inds.new_full((num_bboxes,), -1)
        assigned_labels[valid_mask] = gt_labels[matched_gt_inds].long()
        max_overlaps = assigned_gt_inds.new_full(
            (num_bboxes,), -INF, dtype=torch.float32
        )
        max_overlaps[valid_mask] = matched_pred_ious
        return AssignResult(
            num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels
        )

    def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
        """Use sum of topk pred iou as dynamic k. Refer from OTA and YOLOX.

        Args:
            cost (Tensor): Cost matrix.  shape = (num_valid, num_gt)
            pairwise_ious (Tensor): Pairwise iou matrix.  shape = (num_valid, num_gt)
            num_gt (int): Number of gt.
            valid_mask (Tensor): Mask for valid bboxes.  shape = (num_priors, )
        """
        matching_matrix = torch.zeros_like(cost)  # shape = (num_valid, num_gt)
        # select candidate topk ious for dynamic-k calculation
        candidate_topk = min(self.topk, pairwise_ious.size(0))  # 两个数之间选个最小值，免得报错
        # 降序输出 每个gt与所有候选bbox的 前topk 个 IOU值。 shape = (candidate_topk, num_gt)
        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
        # calculate dynamic k for each gt. 先得到每个 gt 的前topk个IOU值之和，再取整，最后做截断。得到每个gt IOU之和的整数部分
        # shape = (num_gt, ) 这个数组的每个元素是对应gt可以与几个bbox做匹配，最小值为1是因为gt肯定至少有一个bbox与之匹配
        dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
        for gt_idx in range(num_gt):
            _, pos_idx = torch.topk(
                cost[:, gt_idx], k=dynamic_ks[gt_idx].item(), largest=False
            )  # 升序 动态K，选出损失最小的前 dynamic_k 个 bbox
            matching_matrix[:, gt_idx][pos_idx] = 1.0  # gt 与哪个bbox匹配，元素值置为1

        del topk_ious, dynamic_ks, pos_idx
        # shape = (num_valid, )
        prior_match_gt_mask = matching_matrix.sum(1) > 1  # 大于 1 说明存在某些 bbox 会与多个 gt 匹配。
        if prior_match_gt_mask.sum() > 0:  # 判断是否有 bbox 匹配到 多个gt 的情况
            # 下面几行的作用是 去除匹配多个 gt 的 bbox 情况，每个 bbox 只匹配一个 gt
            cost_min, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)  # 选择损失最小的那个 gt 与 bbox 做匹配
            matching_matrix[prior_match_gt_mask, :] *= 0.0
            matching_matrix[prior_match_gt_mask, cost_argmin] = 1.0  # 除损失最小的gt外，其他都置为 0
        # 上面步骤结束后，matching_matrix 已经赋值结束，为 1 说明匹配到了，为 0 说明没有匹配到
        # get foreground mask inside box and center prior
        fg_mask_inboxes = matching_matrix.sum(1) > 0.0  # 大于 0 是前景，等于 0 是背景。 shape = (num_valid, )
        # valid_mask[valid_mask] 的 shape 与 fg_mask_inboxes 是一样的，都是 (num_valid, )
        # valid_mask[valid_mask] 里的 bbox 都是左上角在 gt 内部的，是个很粗糙的分配结果。
        # fg_mask_inboxes 才是最终的分配结果，所以要同步给 valid_mask。
        valid_mask[valid_mask.clone()] = fg_mask_inboxes  # 此处的赋值不会新建一块内存，所以此处的 valid_mask 与 114 行的一样
        # shape = (bbox_match_gt_num, ), 里面的元素是 bbox 匹配到的 gt 的列索引
        matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
        # 获得了 bbox 与其匹配的 gt 的 IOU 值
        matched_pred_ious = (matching_matrix * pairwise_ious).sum(1)[fg_mask_inboxes]

        return matched_pred_ious, matched_gt_inds
